import os
import time
import functools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb

from torch.utils.cpp_extension import load
# Directory containing this file; the CUDA/C++ sources below are resolved
# relative to it.
parent_dir = os.path.dirname(os.path.abspath(__file__))
# Build and load the `render_utils_cuda` extension from its C++/CUDA sources
# when this module is imported. MaskGrid.forward calls its `maskcache_lookup`.
render_utils_cuda = load(
        name='render_utils_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/render_utils.cpp', 'cuda/render_utils_kernel.cu']],
        verbose=True)

# Build and load the `total_variation_cuda` extension the same way.
# DenseGrid.total_variation_add_grad calls its `total_variation_add_grad`.
total_variation_cuda = load(
        name='total_variation_cuda',
        sources=[
            os.path.join(parent_dir, path)
            for path in ['cuda/total_variation.cpp', 'cuda/total_variation_kernel.cu']],
        verbose=True)


def create_grid(type, **kwargs):
    """Instantiate a grid module by its type name.

    Args:
        type: One of 'DenseGrid', 'TensoRFGrid' or 'PlaneGrid'.
        **kwargs: Forwarded unchanged to the selected class constructor.

    Returns:
        The constructed grid module.

    Raises:
        NotImplementedError: If `type` is not one of the supported names.
    """
    if type == 'DenseGrid':
        return DenseGrid(**kwargs)
    if type == 'TensoRFGrid':
        return TensoRFGrid(**kwargs)
    if type == 'PlaneGrid':
        return PlaneGrid(**kwargs)
    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError(
        f"Unknown grid type {type!r}; expected 'DenseGrid', 'TensoRFGrid' or 'PlaneGrid'")


''' Dense 3D grid
'''
class DenseGrid(nn.Module):
    """Dense voxel grid of learnable values, queried by trilinear interpolation.

    The values live in `self.grid` with shape [1, channels, *world_size].
    `xyz_min` / `xyz_max` are registered as buffers and define the box that
    maps onto the grid's normalized [-1, 1] sampling coordinates.
    """

    def __init__(self, channels, world_size, xyz_min, xyz_max, residual_mode=False, **kwargs):
        """
        channels:      number of values stored per voxel
        world_size:    the three grid resolutions
        xyz_min/max:   corners of the box covered by the grid
        residual_mode: accepted for interface compatibility (see note below)
        **kwargs:      accepted and ignored
        """
        super(DenseGrid, self).__init__()
        # The caller-supplied value is overridden here, so residual mode is
        # currently always off for dense grids.
        # NOTE(review): confirm this is intentional before removing the override.
        residual_mode = False
        self.residual_mode = residual_mode
        self.channels = channels
        self.world_size = world_size
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        self.grid = nn.Parameter(torch.zeros([1, channels, *world_size]))

        if residual_mode:
            self.grid_residual = nn.Parameter(torch.zeros([1, channels, *world_size]))
            print("3DGrid version activated !!!!!! residual:",residual_mode)

    @staticmethod
    def _resize(volume, new_world_size):
        '''Trilinearly resample a [1, C, X, Y, Z] tensor to new_world_size.'''
        return F.interpolate(
            volume, size=tuple(new_world_size), mode='trilinear', align_corners=True)

    def commit_residual(self):
        '''Fold the residual into the base grid and reset the residual to zero.

        Does nothing when residual mode is off.
        '''
        if self.residual_mode:
            self.grid = nn.Parameter(self.grid.data+self.grid_residual.data)
            self.grid_residual = nn.Parameter(torch.zeros_like(self.grid_residual))

    def forward(self, xyz,dir=None):
        '''
        xyz: global coordinates to query, shape [..., 3]
        dir: unused
        Returns a tensor of shape [..., channels]; the trailing axis is
        squeezed away when channels == 1.
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,1,-1,3)
        # Normalize to [-1, 1]; the flip reverses the coordinate order because
        # grid_sample's first coordinate addresses the last tensor axis.
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1

        if self.residual_mode:
            # The base grid is detached so only the residual receives gradients.
            volume = self.grid.detach() + self.grid_residual
        else:
            volume = self.grid
        out = F.grid_sample(volume, ind_norm, mode='bilinear', align_corners=True)
        out = out.reshape(self.channels,-1).T.reshape(*shape,self.channels)
        if self.channels == 1:
            out = out.squeeze(-1)
        return out

    def residual_loss_l1(self):
        '''Mean absolute value of the residual grid.

        Returns a zero scalar when the module holds no residual grid instead
        of failing on the missing attribute.
        '''
        residual = getattr(self, 'grid_residual', None)
        if residual is None:
            return self.grid.new_zeros(())
        return torch.norm(residual,p=1)/residual.numel()

    def scale_volume_grid(self, new_world_size):
        '''Resample the stored grid(s) to a new resolution.

        The base grid is always resampled from itself. When a residual grid
        is in use it is resampled as well, so both keep matching shapes.
        `self.world_size` is updated to the new resolution.
        '''
        if self.channels == 0:
            # Nothing to interpolate; just allocate an empty grid of the new size.
            self.grid = nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
            if self.residual_mode:
                self.grid_residual = nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
        else:
            if self.residual_mode:
                self.grid_residual = nn.Parameter(
                    self._resize(self.grid_residual.data, new_world_size))
            self.grid = nn.Parameter(self._resize(self.grid.data, new_world_size))
        self.world_size = new_world_size

    def scale_volume_grid_value(self, new_world_size):
        '''Return the base grid resampled to new_world_size as a new Parameter.

        The module itself is left unchanged.
        '''
        if self.channels == 0:
            return nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
        return nn.Parameter(self._resize(self.grid.data, new_world_size))

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place'''
        # In residual mode the regulariser targets the residual grid instead.
        target = self.grid_residual if self.residual_mode else self.grid
        total_variation_cuda.total_variation_add_grad(
            target, target.grad, wx, wy, wz, dense_mode)

    def get_dense_grid(self):
        '''Return the base grid parameter.'''
        return self.grid

    @torch.no_grad()
    def __isub__(self, val):
        '''In-place `grid -= val` on the base grid, outside autograd.'''
        self.grid.data -= val
        return self

    def extra_repr(self):
        ws = self.world_size
        # world_size may be a tensor or a plain list/tuple.
        ws = ws.tolist() if hasattr(ws, 'tolist') else list(ws)
        return f'channels={self.channels}, world_size={ws}'






class PlaneGrid(nn.Module):
    """Tri-plane feature grid.

    Three learnable 2-D planes (xy, xz, yz), each with `channels // 3`
    feature channels, are sampled bilinearly and concatenated per query point.
    The plane resolution is the requested world size multiplied by
    `config['factor']` (2 when the key is absent).

    In residual mode an extra zero-initialised plane is kept per axis pair;
    queries sample `base.detach() + residual`, so only the residual receives
    gradients from `forward`.
    """

    def __init__(self, channels, world_size, xyz_min, xyz_max,  config, residual_mode = False):
        """
        channels:      total output channels (each plane holds channels // 3)
        world_size:    the three base resolutions, before scaling
        xyz_min/max:   corners of the box covered by the planes
        config:        mapping; only the optional key 'factor' is read here
        residual_mode: keep an additional residual plane per axis pair
        """
        super(PlaneGrid, self).__init__()
        self.scale = config['factor'] if 'factor' in config else 2

        self.channels = channels
        self.config = config
        self.residual_mode = residual_mode
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        X, Y, Z = self._plane_resolution(world_size)
        # world_size stores the scaled (actual) plane resolution.
        self.world_size = torch.tensor([X,Y,Z])
        R = self.channels //3
        Rxy = R
        self.xy_plane = nn.Parameter(torch.randn([1, Rxy, X, Y ]) * 0.1)
        self.xz_plane = nn.Parameter(torch.randn([1, R,  X, Z]) * 0.1)
        self.yz_plane = nn.Parameter(torch.randn([1, R,  Y, Z]) * 0.1)

        if residual_mode:
            self.xy_plane_residual = nn.Parameter(torch.zeros([1, Rxy, X, Y ]))
            self.xz_plane_residual = nn.Parameter(torch.zeros([1, R,  X, Z]))
            self.yz_plane_residual = nn.Parameter(torch.zeros([1, R,  Y, Z]) )

        print("Planes version activated !!!!!! residual:",residual_mode)

    def _plane_resolution(self, world_size):
        '''Scale the three resolutions by self.scale and return plain ints.'''
        X, Y, Z = world_size
        return int(X * self.scale), int(Y * self.scale), int(Z * self.scale)

    @staticmethod
    def _resized(plane, size):
        '''Bilinearly resample a [1, C, H, W] plane to `size` (detached data).'''
        return F.interpolate(plane.data, size=size, mode='bilinear', align_corners=True)

    def commit_residual(self):
        '''Fold each residual plane into its base plane and zero the residual.

        Does nothing when residual mode is off (no residual planes exist).
        '''
        if not self.residual_mode:
            return
        self.xy_plane = nn.Parameter(self.xy_plane.data + self.xy_plane_residual.data)
        self.xz_plane = nn.Parameter(self.xz_plane.data + self.xz_plane_residual.data)
        self.yz_plane = nn.Parameter(self.yz_plane.data + self.yz_plane_residual.data)

        self.xy_plane_residual = nn.Parameter(torch.zeros_like(self.xy_plane_residual))
        self.xz_plane_residual = nn.Parameter(torch.zeros_like(self.xz_plane_residual))
        self.yz_plane_residual = nn.Parameter(torch.zeros_like(self.yz_plane_residual))

    def compute_planes_feat(self, ind_norm):
        '''Sample the three planes at ind_norm and concatenate the features.

        ind_norm: [1, 1, n_pts, >=3] normalized coordinates.
        Returns a [n_pts, 3 * (channels // 3)] tensor ordered xy, xz, yz.
        '''
        if self.residual_mode:
            # Base planes are detached: gradients flow to the residuals only.
            xy = self.xy_plane.detach() + self.xy_plane_residual
            xz = self.xz_plane.detach() + self.xz_plane_residual
            yz = self.yz_plane.detach() + self.yz_plane_residual
        else:
            xy, xz, yz = self.xy_plane, self.xz_plane, self.yz_plane

        # Interp feature (feat shape: [n_pts, n_comp]). The coordinate pairs
        # are given as (second axis, first axis) because grid_sample's first
        # coordinate addresses the last tensor axis.
        feats = [
            F.grid_sample(plane, ind_norm[:,:,:,axes], mode='bilinear', align_corners=True).flatten(0,2).T
            for plane, axes in ((xy, [1,0]), (xz, [2,0]), (yz, [2,1]))
        ]
        # Aggregate components
        return torch.cat(feats, dim=-1)

    def forward(self, xyz, dir=None, center=None):
        '''
        xyz: global coordinates to query, shape [..., 3]
        dir, center: unused
        Returns a tensor of shape [..., channels].
        Raises an exception when channels <= 1 (not implemented).
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,-1,3)
        ind_norm = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min) * 2 - 1
        # Append a zero fourth coordinate (not read by compute_planes_feat).
        ind_norm = torch.cat([ind_norm, torch.zeros_like(ind_norm[...,[0]])], dim=-1)

        if self.channels <= 1:
            raise NotImplementedError("no implement!!!!!!!!!!")
        out = self.compute_planes_feat(ind_norm)
        return out.reshape(*shape,self.channels)

    def scale_volume_grid(self, new_world_size):
        '''Resample the planes to a new (scaled) resolution.

        Without residual mode the base planes are resampled. In residual mode
        the residual planes are resampled, and each base plane is resampled
        too whenever its resolution differs from the target, so that
        `base + residual` stays shape-compatible. `self.world_size` is updated.
        '''
        if self.channels == 0:
            return
        X, Y, Z = self._plane_resolution(new_world_size)
        targets = (('xy_plane', [X,Y]), ('xz_plane', [X,Z]), ('yz_plane', [Y,Z]))
        for name, size in targets:
            base = getattr(self, name)
            if self.residual_mode:
                res_name = name + '_residual'
                setattr(self, res_name, nn.Parameter(self._resized(getattr(self, res_name), size)))
                # Leave a base plane untouched when it already has the target size.
                if list(base.shape[-2:]) == size:
                    continue
            setattr(self, name, nn.Parameter(self._resized(base, size)))
        self.world_size = torch.tensor([X,Y,Z])

    def scale_volume_grid_value(self, new_world_size):
        '''Return resampled copies of the base planes without modifying the module.

        Returns (xy_plane, xz_plane, yz_plane) as Parameters with
        requires_grad=False, or None when channels == 0.
        '''
        if self.channels == 0:
            return
        X, Y, Z = self._plane_resolution(new_world_size)

        xy_plane = nn.Parameter(self._resized(self.xy_plane, [X,Y]), requires_grad=False)
        xz_plane = nn.Parameter(self._resized(self.xz_plane, [X,Z]), requires_grad=False)
        yz_plane = nn.Parameter(self._resized(self.yz_plane, [Y,Z]), requires_grad=False)

        return xy_plane, xz_plane, yz_plane

    def residual_loss_l1(self):
        '''Average over the three residual planes of their mean absolute value.

        Returns a zero scalar when residual mode is off.
        '''
        if not self.residual_mode:
            return self.xy_plane.new_zeros(())
        residuals = (self.xy_plane_residual, self.xz_plane_residual, self.yz_plane_residual)
        return sum(torch.norm(r,p=1)/r.numel() for r in residuals)/3.0

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place

        A smooth-L1 penalty between neighbouring entries of the base planes
        is back-propagated immediately. `dense_mode` is unused.
        NOTE(review): unlike DenseGrid, this always regularises the base
        planes, even in residual mode -- confirm that is intended.
        '''
        loss = wx * F.smooth_l1_loss(self.xy_plane[:,:,1:], self.xy_plane[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.xy_plane[:,:,:,1:], self.xy_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.xz_plane[:,:,1:], self.xz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.xz_plane[:,:,:,1:], self.xz_plane[:,:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.yz_plane[:,:,1:], self.yz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.yz_plane[:,:,:,1:], self.yz_plane[:,:,:,:-1], reduction='sum')
        loss /= 6
        loss.backward()

    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}, n_comp={self.channels //3}'



''' Vector-Matrix decomposed grid
See TensoRF: Tensorial Radiance Fields (https://arxiv.org/abs/2203.09517)
'''
class TensoRFGrid(nn.Module):
    """Vector-matrix factorised grid.

    Holds three planes (xy, xz, yz) and three axis vectors (x, y, z), each
    with `channels // 3` components, at a resolution of world_size * 4.
    When channels > 1 a learnable matrix `f_vec` maps the concatenated
    plane-times-vector components to `channels` outputs.
    """

    def __init__(self, channels, world_size, xyz_min, xyz_max, config):
        """
        channels:    number of output channels
        world_size:  the three base resolutions, before scaling
        xyz_min/max: corners of the box covered by the grid
        config:      stored on the module; not read in this class
        """
        super(TensoRFGrid, self).__init__()
        self.scale = 4
        self.channels = channels
        self.config = config
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        X, Y, Z = self._plane_resolution(world_size)
        # world_size stores the scaled (actual) resolution.
        self.world_size = torch.tensor([X,Y,Z])
        R = self.channels //3
        Rxy = R
        self.xy_plane = nn.Parameter(torch.randn([1, Rxy, X, Y]) * 0.1)
        self.xz_plane = nn.Parameter(torch.randn([1, R, X, Z]) * 0.1)
        self.yz_plane = nn.Parameter(torch.randn([1, R, Y, Z]) * 0.1)
        self.x_vec = nn.Parameter(torch.randn([1, R, X, 1]) * 0.1)
        self.y_vec = nn.Parameter(torch.randn([1, R, Y, 1]) * 0.1)
        self.z_vec = nn.Parameter(torch.randn([1, Rxy, Z, 1]) * 0.1)
        if self.channels > 1:
            self.f_vec = nn.Parameter(torch.ones([R+R+Rxy, channels]))
            nn.init.kaiming_uniform_(self.f_vec, a=np.sqrt(5))

    def _plane_resolution(self, world_size):
        '''Scale the three resolutions by self.scale and return plain ints.'''
        X, Y, Z = world_size
        return int(X * self.scale), int(Y * self.scale), int(Z * self.scale)

    def forward(self, xyz,dir=None):
        '''
        xyz: global coordinates to query, shape [..., 3]
        dir: unused
        Returns [..., channels] when channels > 1, otherwise [...].
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,-1,3)
        ind_norm = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min) * 2 - 1
        # Append a zero coordinate; it is used as the second sampling
        # coordinate for the [.., N, 1] axis vectors.
        ind_norm = torch.cat([ind_norm, torch.zeros_like(ind_norm[...,[0]])], dim=-1)
        if self.channels > 1:
            out = compute_tensorf_feat(
                    self.xy_plane, self.xz_plane, self.yz_plane,
                    self.x_vec, self.y_vec, self.z_vec, self.f_vec, ind_norm)
            return out.reshape(*shape,self.channels)
        out = compute_tensorf_val(
                self.xy_plane, self.xz_plane, self.yz_plane,
                self.x_vec, self.y_vec, self.z_vec, ind_norm)
        return out.reshape(*shape)

    def scale_volume_grid(self, new_world_size):
        '''Resample all planes and vectors to a new (scaled) resolution.

        `f_vec` is left untouched. `self.world_size` is updated so that
        `extra_repr` reports the current resolution.
        '''
        if self.channels == 0:
            return
        X, Y, Z = self._plane_resolution(new_world_size)
        targets = (
            ('xy_plane', [X,Y]), ('xz_plane', [X,Z]), ('yz_plane', [Y,Z]),
            ('x_vec', [X,1]), ('y_vec', [Y,1]), ('z_vec', [Z,1]),
        )
        for name, size in targets:
            resized = F.interpolate(
                getattr(self, name).data, size=size, mode='bilinear', align_corners=True)
            setattr(self, name, nn.Parameter(resized))
        self.world_size = torch.tensor([X,Y,Z])

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place

        A smooth-L1 penalty between neighbouring entries of every plane and
        vector is back-propagated immediately. `dense_mode` is unused.
        '''
        loss = wx * F.smooth_l1_loss(self.xy_plane[:,:,1:], self.xy_plane[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.xy_plane[:,:,:,1:], self.xy_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.xz_plane[:,:,1:], self.xz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.xz_plane[:,:,:,1:], self.xz_plane[:,:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.yz_plane[:,:,1:], self.yz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.yz_plane[:,:,:,1:], self.yz_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.x_vec[:,:,1:], self.x_vec[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.y_vec[:,:,1:], self.y_vec[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.z_vec[:,:,1:], self.z_vec[:,:,:-1], reduction='sum')
        loss /= 6
        loss.backward()

    def get_dense_grid(self):
        '''Expand the factorisation into a dense volume.

        Returns [1, channels, X, Y, Z] when channels > 1, else [1, 1, X, Y, Z].
        '''
        xy, xz, yz = self.xy_plane[0], self.xz_plane[0], self.yz_plane[0]
        x, y, z = self.x_vec[0,:,:,0], self.y_vec[0,:,:,0], self.z_vec[0,:,:,0]
        if self.channels > 1:
            feat = torch.cat([
                torch.einsum('rxy,rz->rxyz', xy, z),
                torch.einsum('rxz,ry->rxyz', xz, y),
                torch.einsum('ryz,rx->rxyz', yz, x),
            ])
            return torch.einsum('rxyz,rc->cxyz', feat, self.f_vec)[None]
        grid = torch.einsum('rxy,rz->xyz', xy, z) + \
               torch.einsum('rxz,ry->xyz', xz, y) + \
               torch.einsum('ryz,rx->xyz', yz, x)
        return grid[None,None]

    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}, n_comp={self.channels //3}'


def compute_tensorf_feat(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, f_vec, ind_norm):
    """Sample the vector-matrix factors and project them with `f_vec`.

    Each plane is sampled at its two coordinates and multiplied elementwise
    with the complementary axis vector; the three products are concatenated
    along the component axis and multiplied by `f_vec`.

    ind_norm is indexed on its last axis with columns 0..3; column 3 is used
    as the second coordinate when sampling the axis vectors.
    Returns a [n_pts, f_vec.shape[1]] tensor.
    """
    def sample(texture, columns):
        # [1, C, 1, n_pts] -> [n_pts, C]
        picked = F.grid_sample(
            texture, ind_norm[:, :, :, columns], mode='bilinear', align_corners=True)
        return picked.flatten(0, 2).T

    plane_xy = sample(xy_plane, [1, 0])
    plane_xz = sample(xz_plane, [2, 0])
    plane_yz = sample(yz_plane, [2, 1])
    line_x = sample(x_vec, [3, 0])
    line_y = sample(y_vec, [3, 1])
    line_z = sample(z_vec, [3, 2])

    # Pair every plane with the vector along its missing axis.
    components = torch.cat(
        [plane_xy * line_z, plane_xz * line_y, plane_yz * line_x], dim=-1)
    return torch.mm(components, f_vec)


def compute_planes_feat(xy_plane, xz_plane, yz_plane, ind_norm):
    """Sample three planes and combine them by elementwise product.

    Each plane is sampled bilinearly at its pair of columns from the last
    axis of ind_norm, giving a [n_pts, n_comp] tensor per plane. The result
    is the product of the three (not a concatenation).
    """
    sampled = []
    for plane, columns in ((xy_plane, [1, 0]), (xz_plane, [2, 0]), (yz_plane, [2, 1])):
        values = F.grid_sample(
            plane, ind_norm[:, :, :, columns], mode='bilinear', align_corners=True)
        # [1, C, 1, n_pts] -> [n_pts, C]
        sampled.append(values.flatten(0, 2).T)

    from_xy, from_xz, from_yz = sampled
    return from_xy * from_xz * from_yz

def compute_tensorf_feat_directional(xy_plane, xz_plane, yz_plane, ind_norm):
    """Sample three 5-D volumes with a third coordinate and concatenate.

    ind_norm gets a leading axis added and is then indexed on its last axis
    with columns (5,1,0), (4,2,0) and (3,2,1), so it needs at least six
    entries there.
    NOTE(review): columns 3..5 are presumably direction components -- confirm
    against the caller.

    For each volume, the sampled result has its first three axes flattened
    and its three remaining axes reversed. The three results are concatenated
    on the last axis. For ind_norm of shape [1, 1, n_pts, 6] and volumes with
    C channels this yields [n_pts, 1, 3 * C].
    """
    ind_norm_n = ind_norm.unsqueeze(0)

    def sample(volume, columns):
        out = F.grid_sample(
            volume, ind_norm_n[:, :, :, :, columns], mode='bilinear', align_corners=True)
        # flatten(0, 2) of the 5-D output is 3-D. Reverse its axes explicitly:
        # `.T` on a non-2-D tensor is deprecated in PyTorch.
        return out.flatten(0, 2).permute(2, 1, 0)

    xy_feat = sample(xy_plane, [5, 1, 0])
    xz_feat = sample(xz_plane, [4, 2, 0])
    yz_feat = sample(yz_plane, [3, 2, 1])

    # Aggregate components
    return torch.cat([xy_feat, xz_feat, yz_feat], dim=-1)

def compute_tensorf_val(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, ind_norm):
    """Evaluate the vector-matrix factorisation as one scalar per point.

    Each plane sample is multiplied with the sample of the complementary
    axis vector and summed over components; the three sums are added.
    Column 3 of ind_norm's last axis is used as the second coordinate when
    sampling the axis vectors. Returns a [n_pts] tensor.
    """
    def sample(texture, columns):
        # [1, C, 1, n_pts] -> [n_pts, C]
        picked = F.grid_sample(
            texture, ind_norm[:, :, :, columns], mode='bilinear', align_corners=True)
        return picked.flatten(0, 2).T

    plane_xy = sample(xy_plane, [1, 0])
    plane_xz = sample(xz_plane, [2, 0])
    plane_yz = sample(yz_plane, [2, 1])
    line_x = sample(x_vec, [3, 0])
    line_y = sample(y_vec, [3, 1])
    line_z = sample(z_vec, [3, 2])

    total = (plane_xy * line_z).sum(-1)
    total = total + (plane_xz * line_y).sum(-1)
    total = total + (plane_yz * line_x).sum(-1)
    return total


''' Mask grid
It supports query for the known free space and unknown space.
'''
class MaskGrid(nn.Module):
    """Boolean occupancy mask over a box, queried through a CUDA lookup.

    The mask is either given directly (`mask`, `xyz_min`, `xyz_max`) or
    derived from a checkpoint at `path` by thresholding the alpha computed
    from its stored density grid.
    """

    def __init__(self, path=None, mask_cache_thres=None, mask=None, xyz_min=None, xyz_max=None):
        """
        path:             checkpoint file to derive the mask from; takes
                          precedence over `mask` when given
        mask_cache_thres: alpha threshold used with `path`
        mask:             3-D tensor converted to bool (used when path is None)
        xyz_min/max:      corners of the box covered by the mask (used when
                          path is None; read from the checkpoint otherwise)

        Raises:
            ValueError: if neither `path` nor `mask` is provided.
        """
        super(MaskGrid, self).__init__()
        if path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            st = torch.load(path)
            self.mask_cache_thres = mask_cache_thres
            # Max-pool over each voxel's 3x3x3 neighbourhood before thresholding.
            density = F.max_pool3d(st['model_state_dict']['density.grid'], kernel_size=3, padding=1, stride=1)
            alpha = 1 - torch.exp(-F.softplus(density + st['model_state_dict']['act_shift']) * st['model_kwargs']['voxel_size_ratio'])
            mask = (alpha >= self.mask_cache_thres).squeeze(0).squeeze(0)
            xyz_min = torch.Tensor(st['model_kwargs']['xyz_min'])
            xyz_max = torch.Tensor(st['model_kwargs']['xyz_max'])
        else:
            if mask is None:
                raise ValueError("MaskGrid requires either `path` or `mask`")
            mask = mask.bool()
            xyz_min = torch.Tensor(xyz_min)
            xyz_max = torch.Tensor(xyz_max)

        self.register_buffer('mask', mask)
        xyz_len = xyz_max - xyz_min
        # Affine map from world coordinates to mask indices:
        # ijk = xyz * scale + shift, with xyz_min -> 0 and xyz_max -> shape - 1.
        self.register_buffer('xyz2ijk_scale', (torch.Tensor(list(mask.shape)) - 1) / xyz_len)
        self.register_buffer('xyz2ijk_shift', -xyz_min * self.xyz2ijk_scale)

    @torch.no_grad()
    def forward(self, xyz):
        '''Skip known free space
        @xyz:   [..., 3] the xyz in global coordinate.
        Returns the lookup result reshaped to xyz.shape[:-1].
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(-1, 3)
        mask = render_utils_cuda.maskcache_lookup(self.mask, xyz, self.xyz2ijk_scale, self.xyz2ijk_shift)
        mask = mask.reshape(shape)
        return mask

    def extra_repr(self):
        # The braces are required so the shape is interpolated, not printed literally.
        return f'mask.shape={list(self.mask.shape)}'

